from pytorch_lightning.callbacks import TQDMProgressBar


class TaskProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        # 获取基础指标（如 loss, epoch 等）
        items = super().get_metrics(trainer, pl_module)

        # 添加您的自定义任务指标
        items["task"] = f"{pl_module.current_task}"
        items["seen_tasks"] = f"{len(pl_module.seen_tasks)}"

        return items
